# -*- coding:utf-8 -*-
"""
@Time : 2020-12-15 11:49
@Author: langengel
@Des: Redis 数据库初始化
"""

from aioredis import create_redis_pool, Redis
from config import Cache, Code
import redis


async def init_cache() -> Redis:
    """
    默认缓存
    :return: cache 连接池
    """
    # 创建连接池
    cache = await create_redis_pool(
        address=f"redis://{Cache['host']}:{Cache['port']}",
        db=Cache['db'],
        encoding='utf-8'
    )

    return cache


async def init_code() -> Redis:
    """
    验证码缓存
    :return: code 连接池
    """
    # 创建连接池
    code = await create_redis_pool(
        address=f"redis://{Code['host']}:{Code['port']}",
        db=Code['db'],
        encoding='utf-8'
    )
    return code



# 以下是同步
class RedisQueue:
    def __init__(self, name):
        pool = redis.ConnectionPool(**Cache)
        self.db = redis.Redis(connection_pool=pool)
        self.key = name
    ###################################################      字符串方法       ###################################################
    #加入缓存,存在会替换
    def addStr(self, values): #value可以为复杂的json
        return self.db.set(self.key, values)

    # 不存在则加入,否则不变
    def addStrNX(self,values):  # value可以为复杂的json
        return self.db.setnx(self.key,values)

    # 加入缓存,存在会替换,并加入过期时间
    def addStrEX(self,time,  values):  # value可以为复杂的json
        return self.db.setex(self.key,time, values)

    #获取缓存
    def getStr(self):
        return self.db.get(self.key)

    ###################################################      列表方法       ###################################################
    # 返回队列里面list元素的长度
    def listLen(self):

        return self.db.llen(self.key)

    # 添加新元素到队列的最右方
    def add2right(self, *values):

        self.db.rpush(self.key, *values) #rpush 如何没有列表会创建 rpushx不创建

    # 返回并删除队列里的第一个元素,如果队列为空返回的None
    def getAndPopFirst(self):

        item = self.db.lpop(self.key)
        return item

    # 获取列表对应索引的值
    def getItemIndex(self,index):
        indexValue = self.db.lindex(self.key, index)
        return indexValue

    #获取列表所有元素
    def getAllListItem(self):
        allItemList=self.db.lrange(self.key, 0, -1)
        return allItemList

    # 获取列表start-end的元素,切头切尾
    def getListItem(self,start,end):
        itemList = self.db.lrange(self.key, start,end)
        return itemList

    #删除列表中value
    def delValue(self,value):
        self.db.lrem(self.key, 1, value)

    # 删除列表中指定数量的value
    def delMoreValue(self,count, value):#count为0表示删除所有值为value的元素
        self.db.lrem(self.key, count, value)

    # 获取并只保留start-end的元素,切头切尾
    def saveStart2end(self, start, end):
        start2endList = self.db.ltrim(self.key, start, end)
        return start2endList

    ###################################################      集合方法       ###################################################
    #加入集合中
    def add2set(self,*values):
        self.db.sadd(self.key, *values)

    #查看元素是否存在集合中
    def memberExist(self,member):

        check=self.db.sismember(self.key,member)
        return check

    ###################################################      散列方法       ###################################################
    #添加
    def set2hash(self,key,value):

        self.db.hset(self.key, key,value)

    #不存在才添加
    def set2hashNX(self,key,value):

        self.db.hsetnx(self.key, key,value)

    # 批量添加
    def batchSet2hash(self, mapping):
        self.db.hmset(self.key, mapping)

    #获取key对应的value
    def getValue(self,key):
        value=self.db.hget(self.key, key) #key不存在会返回None
        return value

    # 批量获取多个key对应的value
    def batchGetValue(self, keyList):
        value = self.db.hmget(self.key, keyList) #key不存在会返回None ['a', 's', None, None]
        return value

    # 删除key-value
    def delKeyValue(self, *keys):
        self.db.hdel(self.key, *keys)


    #判断散列中是否存在
    def keyExist(self,key):
        check=self.db.hexists(self.key, key)
        return check

    # 获取所有key
    def getAllKey(self):
        check = self.db.hkeys(self.key)
        return check

    # 获取所有value
    def getAllValue(self):
        check = self.db.hvals(self.key)
        return check

    # 获取所有key-value
    def getAllKeyValue(self):
        check = self.db.hgetall(self.key)
        return check

    # key对应的value增长amount
    def increaseValue(self,key,amount): #针对值为数量使用 ,key不存在也可以直接加
        check = self.db.hincrby(self.key, key, amount)
        return check

    # 获取散列长度
    def getDicLen(self):
        check = self.db.hlen(self.key)
        return check

    ###################################################      通用方法       ###################################################
    # 查看过期时间
    def checkExpireTime(self):

        item = self.db.ttl(self.key)
        return item

    # 设置过期时间
    def setExpireTime(self,time):
        item = self.db.expire(self.key, time)
        return item

#另一个db，继承使用
class RedisStatistics(RedisQueue):
    def __init__(self, name):
        pool = redis.ConnectionPool(**Code)
        self.db = redis.Redis(connection_pool=pool)
        self.key = name

if __name__ == "__main__":
    pass
    # 使用方法
    # @router.get('/test_redis/key')
    # async def test_redis(req: Request):
    #     cache = await req.app.state.cache.set('k1', 'v1')
    #     print(cache)
    #     return 'ok'

    async def get_value():
        r = await init_code()
        value = await r.get('k1')
        print(f'k1: {value!r}')
        await r.close()


    get_value()
